import jsonlines
from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
from peft import PeftModel
from argparse import ArgumentParser
import torch
from tqdm import tqdm
import os
from tools import DynamicDataset, read_test_jsonl
from torch.utils.data import DataLoader


# Command-line interface for the batch inference script.
#   --data_dir    directory holding the `<data_name>.jsonl` input file
#   --base_model  path/name handed to the tokenizer and base-model loaders
#   --lora_dir    directory of the LoRA adapter applied on top of the base model
#   --output_dir  directory where `<data_name>.jsonl` results are written
#   --batch_size  number of prompts generated per forward pass
#   --max_length  parsed but not read anywhere in the visible script
#                 (NOTE(review): generation uses a fixed max_new_tokens)
#   --data_name   split name; "test" switches on per-example system prompts
parser = ArgumentParser()

for flag, value_type, fallback in (
    ('--data_dir', str, './data/gpt-4'),
    ('--base_model', str, ''),
    ('--lora_dir', str, './lora/gpt-4'),
    ('--output_dir', str, './output/'),
    ("--batch_size", int, 48),
    ("--max_length", int, 100),
    ("--data_name", str, "test"),
):
    parser.add_argument(flag, type=value_type, default=fallback)

opt = parser.parse_args()


# Tokenizer: pad on the left so that, in a padded batch, every prompt ends at
# the last position and generation continues directly from it.
tokenizer = LlamaTokenizer.from_pretrained(opt.base_model)
tokenizer.padding_side = 'left'
tokenizer.pad_token_id = 0

# Base model: no 8-bit quantisation and no explicit torch_dtype, so the
# library's default precision is used; device placement is delegated to
# `device_map="auto"`.
base_llama = LlamaForCausalLM.from_pretrained(
    opt.base_model,
    device_map="auto",
    load_in_8bit=False,
)

# Wrap the base model with the LoRA adapter stored in `--lora_dir`.
model = PeftModel.from_pretrained(base_llama, opt.lora_dir)

# Load the evaluation split. `read_test_jsonl` is a project helper; its result
# is iterated and star-unpacked below, so it is presumably a sequence of
# parallel columns (inputs, outputs, and for "test" also system prompts).
data_path = f"{opt.data_dir}/{opt.data_name}.jsonl"
data = read_test_jsonl(data_path)

# The "test" split is capped at the first 1000 entries of every column.
if opt.data_name == "test":
    data = [column[:1000] for column in data]

# Keep the original order (no shuffling) so results line up with the input.
dataset = DynamicDataset(*data)
loader = DataLoader(dataset, shuffle=False, batch_size=opt.batch_size)

# Decoding settings shared by every batch: nucleus sampling is enabled, but
# with a temperature this close to zero the output is effectively near-greedy.
config = GenerationConfig(do_sample=True, top_p=0.75, temperature=0.001)

# Fallback system prompt: used for every example when the batch does not carry
# its own per-example system prompts (any split other than "test").
system = "You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You are an expert in answer my questions. Your response should follow the format like \"Explanation: ___\nAnswer: ___\"."

# Make sure the destination directory exists; otherwise opening the output
# file raises FileNotFoundError for a fresh --output_dir.
if opt.output_dir:
    os.makedirs(opt.output_dir, exist_ok=True)

# The writer is used as a context manager so the file is flushed and closed
# even if tokenisation or generation raises part-way through the run.
with jsonlines.open(f"{opt.output_dir}/{opt.data_name}.jsonl", 'w') as fo:
    for batch in tqdm(loader):
        # "test" batches provide a system prompt per example; every other
        # split reuses the single fallback prompt defined above.
        if opt.data_name == "test":
            inputs, outputs, systems = batch
        else:
            inputs, outputs = batch
            systems = [system] * len(inputs)

        # Chat-style prompt: one system turn followed by one user turn.
        input_text = [
            f"<|im_start|>system\n{sys_prompt}<|im_end|>\n<|im_start|>user\n{ipt}<|im_end|>"
            for ipt, sys_prompt in zip(inputs, systems)
        ]

        # Pad to the longest prompt in the batch and move tensors to the GPU.
        inputs_t = tokenizer(input_text, padding=True, return_tensors='pt')
        input_ids = inputs_t.input_ids.to('cuda')
        attention_mask = inputs_t.attention_mask.to("cuda")

        # Inference only: no gradients needed.
        # NOTE(review): --max_length is not consulted here; the generation
        # budget is the fixed max_new_tokens below.
        with torch.no_grad():
            generated_response = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                generation_config=config,
                return_dict_in_generate=True,
                max_new_tokens=250
            )

        # NOTE(review): for decoder-only models `sequences` normally includes
        # the prompt tokens, so "R" likely holds prompt + completion — confirm
        # that downstream parsing expects this.
        sentence_ids = generated_response.sequences
        sentences = tokenizer.batch_decode(
            sentence_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        # One JSON line per example: source input, reference output, and the
        # decoded model response.
        for ipt, ot, sent in zip(inputs, outputs, sentences):
            fo.write({"input": ipt, "output": ot, "R": sent})






